from copy import deepcopy
import os
import math
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR
from tensorboardX import SummaryWriter
import argparse

# Run the whole experiment in double precision. set_default_dtype expects a
# torch.dtype (not a legacy tensor type such as torch.DoubleTensor).
torch.set_default_dtype(torch.float64)

# Command-line options of the experiment. `--lamb` and `--gamma` default to
# None so that "not given" can be told apart from an explicit value.
parser = argparse.ArgumentParser(description='Numerical Results')
parser.add_argument('--p', default=512, type=int, help='the dimension of features and prototypes')
parser.add_argument('--num-classes', default=100, type=int, help='the number of classes')
parser.add_argument('--num-per-class', default=10, type=int, help='the number of samples in each class')
parser.add_argument('--lamb', default=None, type=float, help='the regularization coefficient (weight decay)')
parser.add_argument('--gamma', default=None, type=float, help='the trade-off parameter between positive and negative logits')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
parser.add_argument('--s', default=1.0, type=float, help='the scale parameter of learning rate')
parser.add_argument('--exp', default='exp', type=str, help='the experiment name')
parser.add_argument('--mode', default='regularized', choices=['unconstrained', 'regularized'],
                    help='experiment mode (selects the sub-directory of the log path)')
parser.add_argument('--seed', default=123, type=int, help='random seed')

# Parsed once at import time; the rest of the script reads `args` directly.
args = parser.parse_args()


# Pin the process to the first GPU and configure cuDNN.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
torch.backends.cudnn.enabled = True
# NOTE(review): benchmark mode lets cuDNN auto-select algorithms, which the
# PyTorch reproducibility notes list as a source of run-to-run variation.
# Kept as in the original; confirm whether it should be False.
torch.backends.cudnn.benchmark = True
# The attribute is spelled `deterministic`; a misspelled name is silently
# accepted as a new attribute and leaves the real flag untouched.
torch.backends.cudnn.deterministic = True
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed every random number generator used by the script with the same value.
random.seed(args.seed)
np.random.seed(args.seed)
torch.cuda.manual_seed(args.seed)
torch.manual_seed(args.seed)



class Loss(nn.Module):
    """Linear loss on raw logits.

    For every sample the loss is ``-(logit of the labelled class)`` plus
    ``gamma`` times the sum of all remaining logits; the per-sample values
    are averaged over the batch.

    Args:
        gamma: weight of the summed non-target logits.
    """

    def __init__(self, gamma=0.0):
        super().__init__()
        self.gamma = gamma

    def forward(self, logits, labels):
        """Return the batch-mean loss.

        Args:
            logits: 2-D tensor whose second dimension indexes the classes.
            labels: integer class indices, one per row of ``logits``.
        """
        num_classes = logits.size(1)
        # 1.0 at the labelled class, 0.0 elsewhere.
        target_mask = F.one_hot(labels, num_classes).float().to(logits.device)
        target_logit = (logits * target_mask).sum(dim=-1)
        other_logits = (logits * (1 - target_mask)).sum(dim=-1)
        per_sample = self.gamma * other_logits - target_logit
        return per_sample.mean()


def evaluate(out, labels):
    """Return the fraction of rows of ``out`` whose arg-max equals the label.

    Args:
        out: 2-D score tensor; the prediction is the arg-max over dimension 1.
        labels: 1-D tensor of true class indices.

    Returns:
        Accuracy as a Python float in [0, 1].
    """
    predicted = out.argmax(dim=1)
    num_correct = predicted.eq(labels).sum().item()
    num_total = labels.size(0)
    return float(num_correct) / float(num_total)


def get_margin(weight):
    """Return the smallest pairwise angle, in degrees, between rows of ``weight``.

    Rows are L2-normalised, so their Gram matrix holds pairwise cosines. The
    largest off-diagonal cosine corresponds to the smallest angle.
    """
    unit_rows = F.normalize(weight, dim=1)
    cosines = unit_rows @ unit_rows.t()
    # Push the diagonal (self-similarity 1) down to -1 so it never wins the max.
    cosines = cosines - 2 * torch.eye(unit_rows.size(0), device=weight.device)
    # Keep acos away from the end points of its domain.
    cosines = cosines.clamp(-1 + 1e-7, 1 - 1e-7)
    largest_cosine = cosines.max()
    return torch.acos(largest_cosine).item() / math.pi * 180


def projection1(H, W, labels, neg=False, num_classes=None, num_per_class=None):
    """Combine class prototypes with scaled per-class feature sums, then centre.

    For every class ``i`` the rows of ``H`` whose label is ``i`` are summed,
    scaled by ``+-1 / sqrt(num_per_class)`` and added to row ``i`` of a copy
    of ``W``. The mean over classes is then subtracted and the result halved.

    Args:
        H: feature matrix, one row per sample.
        W: prototype matrix, one row per class. It is not modified.
        labels: 1-D integer tensor giving the class of each row of ``H``.
        neg: if True the feature sums are subtracted instead of added.
        num_classes: number of classes; defaults to ``W.size(0)``.
        num_per_class: samples per class used in the scaling; defaults to
            ``H.size(0) // num_classes`` (i.e. assumes balanced classes).

    Returns:
        Tensor with the same shape as ``W``.
    """
    # Derive the sizes from the tensors instead of relying on module globals.
    if num_classes is None:
        num_classes = W.size(0)
    if num_per_class is None:
        num_per_class = H.size(0) // num_classes

    sign = -1 if neg else 1
    scale = sign / math.sqrt(num_per_class)

    # Work on a private copy so the caller's W stays untouched.
    P = W.detach().clone()
    for cls in range(num_classes):
        P[cls] = P[cls] + scale * H[labels == cls].sum(dim=0)

    # Centre across classes (rows).
    P = P - P.mean(dim=0, keepdim=True)
    return P / 2


def w1(sig, lam, s):
    """Return the smaller root of x^2 + lam*(s+1)*x + s*(lam^2 - sig^2) = 0."""
    discriminant = lam**2 * (s - 1)**2 + 4 * s * sig**2
    root = math.sqrt(discriminant)
    return -0.5 * (root + lam * (s + 1))

def w2(sig, lam, s):
    """Return the larger root of x^2 + lam*(s+1)*x + s*(lam^2 - sig^2) = 0."""
    discriminant = lam**2 * (s - 1)**2 + 4 * s * sig**2
    root = math.sqrt(discriminant)
    return 0.5 * (root - lam * (s + 1))

def wh(lam, gamma, s, C, N, neg=False):
    """Return the feature-side mixing coefficient.

    With ``sig = (1 + gamma) / (C * sqrt(N))`` and ``lo, hi = w1, w2`` evaluated
    at ``(sig, lam, s)``, the result is
    ``(+-s*sig - s*lam - lo) / (hi - lo)``; the sign is negative when ``neg``
    is True.
    """
    sign = -1 if neg else 1
    sig = (1 + gamma) / (C * math.sqrt(N))
    lo = w1(sig, lam, s)
    hi = w2(sig, lam, s)
    numerator = sign * s * sig - s * lam - lo
    return numerator / (hi - lo)

def ww(lam, gamma, s, C, N, neg=False):
    """Return the prototype-side mixing coefficient.

    With ``sig = (1 + gamma) / (C * sqrt(N))`` and ``lo, hi = w1, w2`` evaluated
    at ``(sig, lam, s)``, the result is
    ``(s*lam + hi +- sig) / (hi - lo)``; the sign is negative when ``neg``
    is True.
    """
    sign = -1 if neg else 1
    sig = (1 + gamma) / (C * math.sqrt(N))
    lo = w1(sig, lam, s)
    hi = w2(sig, lam, s)
    numerator = s * lam + hi + sign * sig
    return numerator / (hi - lo)




# Problem sizes from the command line: classes, samples per class, dimension.
C, N, p = args.num_classes, args.num_per_class, args.p


# Labels cycle through 0..C-1 and that cycle is repeated N times, so row k of
# H belongs to class k % C.
labels = torch.LongTensor(list(range(C)) * N).to(device)
# Random features (one row per sample) and prototypes (one row per class).
H = torch.randn(C * N, p).to(device)
W = torch.randn(C, p).to(device)
# W is immediately re-initialised in place; the randn draw above still
# advances the random number generator.
nn.init.kaiming_uniform_(W)

# Projections of the initial (H, W), taken before autograd is switched on so
# they are plain constants.
P1_pos = projection1(H, W, labels)
P1_neg = projection1(H, W, labels, neg=True)

# From here on H and W are the optimisation variables.
H.requires_grad = True
W.requires_grad = True

# Learning rates: the features use the prototype rate scaled by s.
s = args.s
lr_w = args.lr
lr_h = s * lr_w
# Both coefficients fall back to a built-in default only when the matching
# command-line option was not given.
gamma = (1. / (C - 1) * 0.5) if args.gamma is None else args.gamma
lamb = (1 + gamma) / (C * math.sqrt(N)) if args.lamb is None else args.lamb

# Mixing coefficients for the positive / negative projections.
pi_h_pos = wh(lamb, gamma, s, C, N)
pi_h_neg = wh(lamb, gamma, s, C, N, neg=True)

pi_w_pos = ww(lamb, gamma, s, C, N)
pi_w_neg = ww(lamb, gamma, s, C, N, neg=True)


# Reference matrices that the training loop compares the normalised H and W
# against. repeat((N, 1)) tiles the C class rows N times, matching the cyclic
# label layout of H.
H_star = 1. / math.sqrt(N) * (pi_h_pos * P1_pos.repeat((N, 1)) - pi_h_neg * P1_neg.repeat((N, 1)))
W_star = pi_w_pos * P1_pos + pi_w_neg * P1_neg


# Plain SGD with weight decay; features and prototypes get separate learning
# rates through two parameter groups.
param_groups = [
    {'params': H, 'lr': lr_h},
    {'params': W, 'lr': lr_w},
]
optimizer = torch.optim.SGD(param_groups, weight_decay=lamb)

criterion = Loss(gamma=gamma)

# Log directory: ./log/<mode>/<exp>/<hyper-parameter summary>.
run_dir = '/dim={}, C={}, N={}, lambda={}, gamma={}, lr={}, scale={}'.format(p, C, N, lamb, gamma, lr_w, s)
store_name = './log/' + args.mode + '/' + args.exp + run_dir
tf_writer = SummaryWriter(log_dir=store_name)

# Per-iteration histories, filled by the training loop.
acc_list, m_list = [], []
norm_w_list, norm_h_list = [], []
error_h_list, error_w_list = [], []

epochs = 50000

# The reference matrices never change, so normalise them once up front.
H_star_unit = H_star / torch.norm(H_star)
W_star_unit = W_star / torch.norm(W_star)

for ep in range(epochs):
    # One full-batch gradient step on both the features H and the prototypes W.
    out = F.linear(H, W)
    loss = criterion(out, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Diagnostics only: keep them out of the autograd graph and store plain
    # Python floats. Appending grad-tracking tensors would make the later
    # conversion of the histories to NumPy arrays fail.
    with torch.no_grad():
        acc = evaluate(out, labels)  # accuracy of the logits from before the step
        margin = get_margin(W)
        norm_w_t = torch.norm(W)
        norm_h_t = torch.norm(H)
        norm_w = norm_w_t.item()
        norm_h = norm_h_t.item()

        # Squared distance between the unit-norm iterate and the unit-norm reference.
        error_h = torch.sum((H / norm_h_t - H_star_unit)**2).item()
        error_w = torch.sum((W / norm_w_t - W_star_unit)**2).item()

    acc_list.append(acc)
    m_list.append(margin)
    norm_w_list.append(norm_w)
    norm_h_list.append(norm_h)
    error_h_list.append(error_h)
    error_w_list.append(error_w)

    tf_writer.add_scalar('acc', acc, ep)
    tf_writer.add_scalar('margin', margin, ep)
    tf_writer.add_scalar('W', norm_w, ep)
    tf_writer.add_scalar('H', norm_h, ep)
    tf_writer.add_scalar('err_h', error_h, ep)
    tf_writer.add_scalar('err_w', error_w, ep)

    if ep % 200 == 0:
        print('Iter {}: loss={:.4f}, acc={:.4f}, margin={:.4f}, norm_w={:.4f}, norm_f={:.4f}, error_h={:.4f}, error_w={:.4f}'.format(ep, loss.item(), acc, margin, norm_w, norm_h, error_h, error_w))
    # NOTE(review): called on every iteration; probably unnecessary now that
    # no graph-holding tensors are kept alive. Confirm before removing.
    torch.cuda.empty_cache()

# Flush pending events and release the log file now that training is done.
tf_writer.close()


# Freeze every per-iteration history into a NumPy array.
acc_list, m_list = np.array(acc_list), np.array(m_list)
norm_w_list, norm_h_list = np.array(norm_w_list), np.array(norm_h_list)
error_h_list, error_w_list = np.array(error_h_list), np.array(error_w_list)

# Write one .npy file per metric into the run's log directory.
for file_name, history in (
    ('/acc.npy', acc_list),
    ('/margin.npy', m_list),
    ('/norm_w.npy', norm_w_list),
    ('/norm_h.npy', norm_h_list),
    ('/error_w.npy', error_w_list),
    ('/error_h.npy', error_h_list),
):
    np.save(store_name + file_name, history)